{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "klGNgWREsvQv"
      },
      "source": [
        "##### Copyright 2018 The TF-Agents Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "nQnmcm0oI1Q-"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pmDI-h7cI0tI"
      },
      "source": [
        "# TF-Agent を使用した Deep Q Network のトレーニング\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"> TensorFlow.orgで表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/agents/tutorials/1_dqn_tutorial.ipynb\">     <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colabで実行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/agents/tutorials/1_dqn_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示{</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/agents/tutorials/1_dqn_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード/a0}</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lsaQlK8fFQqH"
      },
      "source": [
        "## はじめに\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cKOCZlhUgXVK"
      },
      "source": [
        "この例は、Cartpole環境でTF-Agentsライブラリを使用して[ DQN（Deep Q Networks）](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)エージェントをトレーニングする方法を示しています。\n",
        "\n",
        "![Cartpole environment](https://raw.githubusercontent.com/tensorflow/agents/master/docs/tutorials/images/cartpole.png)\n",
        "\n",
        "ここでは、トレーニング、評価、データ収集のための強化学習（RL）パイプラインのすべてのコンポーネントについて説明します。\n",
        "\n",
        "このコードをライブで実行するには、上の [Google Colabで実行] リンクをクリックしてください。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1u9QVVsShC9X"
      },
      "source": [
        "## セットアップ"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kNrNXKI7bINP"
      },
      "source": [
        "以下の依存関係をインストールしていない場合は、実行します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KEHR2Ui-lo8O"
      },
      "outputs": [],
      "source": [
        "!sudo apt-get install -y xvfb ffmpeg\n",
        "!pip install 'gym==0.10.11'\n",
        "!pip install 'imageio==2.4.0'\n",
        "!pip install PILLOW\n",
        "!pip install 'pyglet==1.3.2'\n",
        "!pip install pyvirtualdisplay\n",
        "!pip install tf-agents"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sMitx5qSgJk1"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import, division, print_function\n",
        "\n",
        "import base64\n",
        "import imageio\n",
        "import IPython\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import PIL.Image\n",
        "import pyvirtualdisplay\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "from tf_agents.agents.dqn import dqn_agent\n",
        "from tf_agents.drivers import dynamic_step_driver\n",
        "from tf_agents.environments import suite_gym\n",
        "from tf_agents.environments import tf_py_environment\n",
        "from tf_agents.eval import metric_utils\n",
        "from tf_agents.metrics import tf_metrics\n",
        "from tf_agents.networks import q_network\n",
        "from tf_agents.policies import random_tf_policy\n",
        "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n",
        "from tf_agents.trajectories import trajectory\n",
        "from tf_agents.utils import common"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J6HsdS5GbSjd"
      },
      "outputs": [],
      "source": [
        "tf.compat.v1.enable_v2_behavior()\n",
        "\n",
        "# Set up a virtual display for rendering OpenAI gym environments.\n",
        "display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NspmzG4nP3b9"
      },
      "outputs": [],
      "source": [
        "tf.version.VERSION"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LmC0NDhdLIKY"
      },
      "source": [
        "## ハイパーパラメータ"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HC1kNrOsLSIZ"
      },
      "outputs": [],
      "source": [
        "num_iterations = 20000 # @param {type:\"integer\"}\n",
        "\n",
        "initial_collect_steps = 1000  # @param {type:\"integer\"} \n",
        "collect_steps_per_iteration = 1  # @param {type:\"integer\"}\n",
        "replay_buffer_max_length = 100000  # @param {type:\"integer\"}\n",
        "\n",
        "batch_size = 64  # @param {type:\"integer\"}\n",
        "learning_rate = 1e-3  # @param {type:\"number\"}\n",
        "log_interval = 200  # @param {type:\"integer\"}\n",
        "\n",
        "num_eval_episodes = 10  # @param {type:\"integer\"}\n",
        "eval_interval = 1000  # @param {type:\"integer\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VMsJC3DEgI0x"
      },
      "source": [
        "## 環境\n",
        "\n",
        "強化学習（RL）では、環境はタスクまたは解決すべき問題を表します。標準環境は、`tf_agents.environments`スイートを使用して TF-Agent で作成できます。TF-Agent には、OpenAI Gym、Atari、DM Control などのソースから環境を読み込むためのスイートがあります。\n",
        "\n",
        "OpenAI Gym スイートから CartPole 環境を読み込みます。 "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pYEz-S9gEv2-"
      },
      "outputs": [],
      "source": [
        "env_name = 'CartPole-v0'\n",
        "env = suite_gym.load(env_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IIHYVBkuvPNw"
      },
      "source": [
        "この環境をレンダリングして、どのように見えるかを確認できます。台車の上に自由に振り動く棒を立て、その棒が倒れないように台車を左右に動かすことが目標です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RlO7WIQHu_7D"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "env.reset()\n",
        "PIL.Image.fromarray(env.render())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B9_lskPOey18"
      },
      "source": [
        "`environment.step`メソッドは、環境内で`action`を取り、環境の次の観測と行動の報酬を含む`TimeStep`タプルを返します。\n",
        "\n",
        "`time_step_spec()`メソッドは、`TimeStep`タプルの仕様を返します。その`observation`属性は、観測の形状、データ型、および許容値の範囲を示します。`reward`属性は、報酬についての同じ詳細を示します。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "exDv57iHfwQV"
      },
      "outputs": [],
      "source": [
        "print('Observation Spec:')\n",
        "print(env.time_step_spec().observation)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UxiSyCbBUQPi"
      },
      "outputs": [],
      "source": [
        "print('Reward Spec:')\n",
        "print(env.time_step_spec().reward)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b_lHcIcqUaqB"
      },
      "source": [
        "`action_spec()`メソッドは、有効な行動の形状、データタイプ、および許容値を返します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bttJ4uxZUQBr"
      },
      "outputs": [],
      "source": [
        "print('Action Spec:')\n",
        "print(env.action_spec())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eJCgJnx3g0yY"
      },
      "source": [
        "CartPole 環境では\n",
        "\n",
        "- `observation`は以下の 4 つの float 型の配列です。\n",
        "    - 台車の位置と速度\n",
        "    - 棒の角度位置と速度\n",
        "- `reward`はスカラーの浮動小数点値です\n",
        "- `action`は可能な値が 2 つだけのスカラー整数です。\n",
        "    - `0` — 「左に移動」\n",
        "    - `1` — 「右に移動」\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V2UGR5t_iZX-"
      },
      "outputs": [],
      "source": [
        "time_step = env.reset()\n",
        "print('Time step:')\n",
        "print(time_step)\n",
        "\n",
        "action = np.array(1, dtype=np.int32)\n",
        "\n",
        "next_time_step = env.step(action)\n",
        "print('Next time step:')\n",
        "print(next_time_step)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4JSc9GviWUBK"
      },
      "source": [
        "通常、2 つの環境（トレーニング用、評価用）のみがインスタンス化されます。 "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N7brXNIGWXjC"
      },
      "outputs": [],
      "source": [
        "train_py_env = suite_gym.load(env_name)\n",
        "eval_py_env = suite_gym.load(env_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zuUqXAVmecTU"
      },
      "source": [
        "CartPole 環境は、ほとんどの環境と同様に純粋な Python で記述されています。これは、`TFPyEnvironment`ラッパーを使用して TensorFlow に変換されます。\n",
        "\n",
        "従来の環境の API は Numpy 配列を使用します。`TFPyEnvironment`は、これらを`Tensors`に変換して、Tensorflow エージェントおよびポリシーとの互換性を確保します。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xp-Y4mD6eDhF"
      },
      "outputs": [],
      "source": [
        "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n",
        "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E9lW_OZYFR8A"
      },
      "source": [
        "## エージェント\n",
        "\n",
        "RL問題の解決に使用されるアルゴリズムは、`Agent`で表されます。TF-Agent は、以下を含むさまざまな`Agents`の標準実装を提供します。\n",
        "\n",
        "- [DQN](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)（本チュートリアルで使用）\n",
        "- [REINFORCE](http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)\n",
        "- [DDPG](https://arxiv.org/pdf/1509.02971.pdf)\n",
        "- [TD3](https://arxiv.org/pdf/1802.09477.pdf)\n",
        "- [PPO](https://arxiv.org/abs/1707.06347)\n",
        "- [SAC](https://arxiv.org/abs/1801.01290)\n",
        "\n",
        "DQN エージェントは、個別の行動領域がある任意の環境で使用できます。\n",
        "\n",
        "DQN エージェントの中心にあるのが`QNetwork`です。これは、環境から観察を得て、すべての行動の`QValues`（予期される戻り値）を予測する方法を学習するニューラルネットワークモデルです。\n",
        "\n",
        "`tf_agents.networks.q_network`を使用して`QNetwork`を作成し、`observation_spec`、`action_spec`、およびモデルの非表示レイヤーの数とサイズを表すタプルを渡します。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TgkdEPg_muzV"
      },
      "outputs": [],
      "source": [
        "fc_layer_params = (100,)\n",
        "\n",
        "q_net = q_network.QNetwork(\n",
        "    train_env.observation_spec(),\n",
        "    train_env.action_spec(),\n",
        "    fc_layer_params=fc_layer_params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z62u55hSmviJ"
      },
      "source": [
        "`tf_agents.agents.dqn.dqn_agent`を使用して`DqnAgent`をインスタンス化します。`time_step_spec`、`action_spec`、および QNetwork に加えて、エージェントコンストラクタにもオプティマイザ（この場合、`AdamOptimizer`）、損失関数、および整数ステップカウンタが必要です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jbY4yrjTEyc9"
      },
      "outputs": [],
      "source": [
        "optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)\n",
        "\n",
        "train_step_counter = tf.Variable(0)\n",
        "\n",
        "agent = dqn_agent.DqnAgent(\n",
        "    train_env.time_step_spec(),\n",
        "    train_env.action_spec(),\n",
        "    q_network=q_net,\n",
        "    optimizer=optimizer,\n",
        "    td_errors_loss_fn=common.element_wise_squared_loss,\n",
        "    train_step_counter=train_step_counter)\n",
        "\n",
        "agent.initialize()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I0KLrEPwkn5x"
      },
      "source": [
        "## ポリシー\n",
        "\n",
        "ポリシーは、エージェントが環境で行動する方法を定義します。通常、強化学習の目標は、ポリシーが望ましい結果を生成するまで、基礎となるモデルをトレーニングすることです。\n",
        "\n",
        "このチュートリアルでは\n",
        "\n",
        "- 望ましい結果は、台車の上の棒が倒れないようにバランスを保つことです。\n",
        "- ポリシーは、`time_step`観測ごとに行動（左または右）を返します。\n",
        "\n",
        "エージェントには 2 つのポリシーが含まれています。\n",
        "\n",
        "- `agent.policy` — 評価とデプロイに使用される主なポリシー。\n",
        "- `agent.collect_policy` — データ収集に使用される補助的なポリシー。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BwY7StuMkuV4"
      },
      "outputs": [],
      "source": [
        "eval_policy = agent.policy\n",
        "collect_policy = agent.collect_policy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2Qs1Fl3dV0ae"
      },
      "source": [
        "ポリシーはエージェントとは無関係に作成できます。たとえば、`tf_agents.policies.random_tf_policy`を使用して、各`time_step`の行動をランダムに選択するポリシーを作成できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HE37-UCIrE69"
      },
      "outputs": [],
      "source": [
        "random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),\n",
        "                                                train_env.action_spec())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dOlnlRRsUbxP"
      },
      "source": [
        "ポリシーから行動を取得するには、`policy.action(time_step)`メソッドを呼び出します。`time_step`には、環境からの観測が含まれています。このメソッドは、3 つのコンポーネントを持つ名前付きタプルである`PolicyStep`を返します。\n",
        "\n",
        "- `action` — 実行する行動（ここでは`0`または`1`)\n",
        "- `state` — ステートフルポリシー（RNNベース）に使用\n",
        "- `info` — 行動のログ確率などの補助データ"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5gCcpXswVAxk"
      },
      "outputs": [],
      "source": [
        "example_environment = tf_py_environment.TFPyEnvironment(\n",
        "    suite_gym.load('CartPole-v0'))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D4DHZtq3Ndis"
      },
      "outputs": [],
      "source": [
        "time_step = example_environment.reset()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PRFqAUzpNaAW"
      },
      "outputs": [],
      "source": [
        "random_policy.action(time_step)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "94rCXQtbUbXv"
      },
      "source": [
        "## 指標と評価\n",
        "\n",
        "ポリシーの評価に使用される最も一般的な指標は、平均リターンです。リターンは、エピソードの環境でポリシーを実行中に取得した報酬の合計です。エピソードは何回か実行され、平均リターンが生成されます。\n",
        "\n",
        "次の関数は、ポリシー、環境、およびエピソードの数を指定して、ポリシーの平均リターンを計算します。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bitzHo5_UbXy"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "def compute_avg_return(environment, policy, num_episodes=10):\n",
        "\n",
        "  total_return = 0.0\n",
        "  for _ in range(num_episodes):\n",
        "\n",
        "    time_step = environment.reset()\n",
        "    episode_return = 0.0\n",
        "\n",
        "    while not time_step.is_last():\n",
        "      action_step = policy.action(time_step)\n",
        "      time_step = environment.step(action_step.action)\n",
        "      episode_return += time_step.reward\n",
        "    total_return += episode_return\n",
        "\n",
        "  avg_return = total_return / num_episodes\n",
        "  return avg_return.numpy()[0]\n",
        "\n",
        "\n",
        "# See also the metrics module for standard implementations of different metrics.\n",
        "# https://github.com/tensorflow/agents/tree/master/tf_agents/metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_snCVvq5Z8lJ"
      },
      "source": [
        "この計算を`random_policy`で実行すると、環境のベースラインパフォーマンスが示されます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9bgU6Q6BZ8Bp"
      },
      "outputs": [],
      "source": [
        "compute_avg_return(eval_env, random_policy, num_eval_episodes)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NLva6g2jdWgr"
      },
      "source": [
        "## 再生バッファ\n",
        "\n",
        "再生バッファは、環境から収集されたデータを追跡します。このチュートリアルでは最も一般的な`tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer`を使用します。\n",
        "\n",
        "コンストラクタは、収集するデータの仕様を必要とします。これは、`collect_data_spec`メソッドを使用してエージェントから入手できます。バッチサイズと最大バッファ長も必要です。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vX2zGUWJGWAl"
      },
      "outputs": [],
      "source": [
        "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n",
        "    data_spec=agent.collect_data_spec,\n",
        "    batch_size=train_env.batch_size,\n",
        "    max_length=replay_buffer_max_length)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZGNTDJpZs4NN"
      },
      "source": [
        "ほとんどのエージェントの場合、`collect_data_spec`は、`Trajectory`と呼ばれる名前付きタプルであり、観測、行動、報酬、およびその他の要素の仕様が含まれています。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_IZ-3HcqgE1z"
      },
      "outputs": [],
      "source": [
        "agent.collect_data_spec"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sy6g1tGcfRlw"
      },
      "outputs": [],
      "source": [
        "agent.collect_data_spec._fields"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rVD5nQ9ZGo8_"
      },
      "source": [
        "## データ収集\n",
        "\n",
        "次に、環境でランダムポリシーを数ステップ実行し、データを再生バッファに記録します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wr1KSAEGG4h9"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "def collect_step(environment, policy, buffer):\n",
        "  time_step = environment.current_time_step()\n",
        "  action_step = policy.action(time_step)\n",
        "  next_time_step = environment.step(action_step.action)\n",
        "  traj = trajectory.from_transition(time_step, action_step, next_time_step)\n",
        "\n",
        "  # Add trajectory to the replay buffer\n",
        "  buffer.add_batch(traj)\n",
        "\n",
        "def collect_data(env, policy, buffer, steps):\n",
        "  for _ in range(steps):\n",
        "    collect_step(env, policy, buffer)\n",
        "\n",
        "collect_data(train_env, random_policy, replay_buffer, steps=100)\n",
        "\n",
        "# This loop is so common in RL, that we provide standard implementations. \n",
        "# For more details see the drivers module.\n",
        "# https://www.tensorflow.org/agents/api_docs/python/tf_agents/drivers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84z5pQJdoKxo"
      },
      "source": [
        "再生バッファは Trajectory のコレクションではありません。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4wZnLu2ViO4E"
      },
      "outputs": [],
      "source": [
        "# For the curious:\n",
        "# Uncomment to peel one of these off and inspect it.\n",
        "# iter(replay_buffer.as_dataset()).next()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TujU-PMUsKjS"
      },
      "source": [
        "エージェントは再生バッファにアクセスする必要があります。これは、エージェントにデータをフィードするイテレーション可能な`tf.data.Dataset`パイプラインを作成することにより提供されます。\n",
        "\n",
        "再生バッファの各行には、1 つの観測ステップのみが格納されます。しかし、DQN エージェントは損失を計算するためにその時点の観測と次の観測を両方必要とするので、データセットパイプラインは、バッチ内の要素ごとに2つの隣接する行をサンプリングします(`num_steps=2`)。\n",
        "\n",
        "また、このデータセットは、並列呼び出しを実行してデータをプリフェッチすることにより最適化されます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ba7bilizt_qW"
      },
      "outputs": [],
      "source": [
        "# Dataset generates trajectories with shape [Bx2x...]\n",
        "dataset = replay_buffer.as_dataset(\n",
        "    num_parallel_calls=3, \n",
        "    sample_batch_size=batch_size, \n",
        "    num_steps=2).prefetch(3)\n",
        "\n",
        "\n",
        "dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K13AST-2ppOq"
      },
      "outputs": [],
      "source": [
        "iterator = iter(dataset)\n",
        "\n",
        "print(iterator)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Th5w5Sff0b16"
      },
      "outputs": [],
      "source": [
        "# For the curious:\n",
        "# Uncomment to see what the dataset iterator is feeding to the agent.\n",
        "# Compare this representation of replay data \n",
        "# to the collection of individual trajectories shown earlier.\n",
        "\n",
        "# iterator.next()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hBc9lj9VWWtZ"
      },
      "source": [
        "## エージェントのトレーニング\n",
        "\n",
        "トレーニングループ時には、以下の 2 つが行われる必要があります。\n",
        "\n",
        "- 環境からデータを収集する\n",
        "- そのデータを使用してエージェントのニューラルネットワークをトレーニングする\n",
        "\n",
        "この例では、定期的にポリシーを評価し、その時点のスコアを出力します。\n",
        "\n",
        "以下の実行には 5 分ほどかかります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0pTbJ3PeyF-u"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "try:\n",
        "  %%time\n",
        "except:\n",
        "  pass\n",
        "\n",
        "# (Optional) Optimize by wrapping some of the code in a graph using TF function.\n",
        "agent.train = common.function(agent.train)\n",
        "\n",
        "# Reset the train step\n",
        "agent.train_step_counter.assign(0)\n",
        "\n",
        "# Evaluate the agent's policy once before training.\n",
        "avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)\n",
        "returns = [avg_return]\n",
        "\n",
        "for _ in range(num_iterations):\n",
        "\n",
        "  # Collect a few steps using collect_policy and save to the replay buffer.\n",
        "  for _ in range(collect_steps_per_iteration):\n",
        "    collect_step(train_env, agent.collect_policy, replay_buffer)\n",
        "\n",
        "  # Sample a batch of data from the buffer and update the agent's network.\n",
        "  experience, unused_info = next(iterator)\n",
        "  train_loss = agent.train(experience).loss\n",
        "\n",
        "  step = agent.train_step_counter.numpy()\n",
        "\n",
        "  if step % log_interval == 0:\n",
        "    print('step = {0}: loss = {1}'.format(step, train_loss))\n",
        "\n",
        "  if step % eval_interval == 0:\n",
        "    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)\n",
        "    print('step = {0}: Average Return = {1}'.format(step, avg_return))\n",
        "    returns.append(avg_return)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "68jNcA_TiJDq"
      },
      "source": [
        "## 可視化\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aO-LWCdbbOIC"
      },
      "source": [
        "### プロット\n",
        "\n",
        "`matplotlib.pyplot`を使用して、トレーニング中にポリシーがどのように改善されたかをグラフ化します。\n",
        "\n",
        "`Cartpole-v0`の 1 回のイテレーションには、200のタイムステップがあります。環境は、棒が立ったままでいる各ステップに対して`+1`の報酬を与えるので、1 つのエピソードの最大リターンは 200 です。グラフは、トレーニング中に評価されるたびに、最大値に向かって増加するリターンを示しています。（多少不安定になり、毎回単調に増加しないこともあります。）"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NxtL1mbOYCVO"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "\n",
        "iterations = range(0, num_iterations + 1, eval_interval)\n",
        "plt.plot(iterations, returns)\n",
        "plt.ylabel('Average Return')\n",
        "plt.xlabel('Iterations')\n",
        "plt.ylim(top=250)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M7-XpPP99Cy7"
      },
      "source": [
        "### 動画"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9pGfGxSH32gn"
      },
      "source": [
        "グラフは便利ですが、エージェントが環境内で実際にタスクを実行している様子を動画で視覚化するとさらに分かりやすくなります。\n",
        "\n",
        "まず、ノートブックに動画を埋め込む関数を作成します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ULaGr8pvOKbl"
      },
      "outputs": [],
      "source": [
        "def embed_mp4(filename):\n",
        "  \"\"\"Embeds an mp4 file in the notebook.\"\"\"\n",
        "  video = open(filename,'rb').read()\n",
        "  b64 = base64.b64encode(video)\n",
        "  tag = '''\n",
        "  <video width=\"640\" height=\"480\" controls>\n",
        "    <source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\">\n",
        "  Your browser does not support the video tag.\n",
        "  </video>'''.format(b64.decode())\n",
        "\n",
        "  return IPython.display.HTML(tag)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9c_PH-pX4Pr5"
      },
      "source": [
        "CartPole ゲームのいくつかのエピソードをエージェントで繰り返します。基礎となる Python 環境（TensorFlow 環境ラッパーの「内部」のもの）は、環境の状態のイメージを出力する`render()`メソッドを提供します。これらは動画に収集できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "owOVWB158NlF"
      },
      "outputs": [],
      "source": [
        "def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):\n",
        "  filename = filename + \".mp4\"\n",
        "  with imageio.get_writer(filename, fps=fps) as video:\n",
        "    for _ in range(num_episodes):\n",
        "      time_step = eval_env.reset()\n",
        "      video.append_data(eval_py_env.render())\n",
        "      while not time_step.is_last():\n",
        "        action_step = policy.action(time_step)\n",
        "        time_step = eval_env.step(action_step.action)\n",
        "        video.append_data(eval_py_env.render())\n",
        "  return embed_mp4(filename)\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "create_policy_eval_video(agent.policy, \"trained-agent\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "povaAOcZygLw"
      },
      "source": [
        "トレーニングされたエージェント（上記）をランダムに移動するエージェントと比較してみてください。（ランダムに移動するエージェントは、トレーニングされたエージェントと比べて上手くできません。）"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pJZIdC37yNH4"
      },
      "outputs": [],
      "source": [
        "create_policy_eval_video(random_policy, \"random-agent\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "1_dqn_tutorial.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
